/* Copyright 2012 Tobias Marschall
 * 
 * This file is part of CLEVER.
 * 
 * CLEVER is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * CLEVER is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with CLEVER.  If not, see <http://www.gnu.org/licenses/>.
 */

#include <math.h>
#include <limits>
#include <map>
#include <boost/unordered_set.hpp>

#include "CliqueWriter.h"

using namespace std;

/**
 * Creates a writer that tests cliques and emits variation calls to os.
 *
 * @param os               stream receiving the calls (with output_all, every
 *                         clique is written here as soon as it is added)
 * @param variation_caller caller used to test each clique
 * @param read_groups      read group / sample information; must be non-null
 *                         when multisample is true
 * @param multisample      test every non-empty subset of samples separately
 * @param output_all       write every clique from add() instead of applying
 *                         FDR control in finish()
 * @param fdr_threshold    FDR threshold used for significance
 * @param verbose          in finish(), write full statistics rather than just
 *                         the variation
 */
CliqueWriter::CliqueWriter(ostream& os, VariationCaller& variation_caller, const ReadGroups* read_groups, bool multisample, bool output_all, double fdr_threshold, bool verbose) : os(os), variation_caller(variation_caller), read_groups(read_groups) {
	// Multi-sample mode maps alignments to samples via read groups.
	assert(!multisample || (read_groups != 0));
	// Caller-supplied configuration (parameters shadow members, hence this->).
	this->multisample = multisample;
	this->output_all = output_all;
	this->fdr_threshold = fdr_threshold;
	this->verbose = verbose;
	// Nothing has been added yet.
	total_count = 0;
	total_insertion_cliques = 0;
	total_deletion_cliques = 0;
	finished = false;
	// Read list output stays disabled until enableReadListOutput() is called.
	read_list_os = 0;
	// -1 until finish() determines the counts (they remain -1 with output_all).
	significant_ins_count = -1;
	significant_del_count = -1;
}

/**
 * Releases the per-clique read lists.
 *
 * add() allocates a read list for every clique it stores in clique_list, but
 * only while read list output is enabled; otherwise there is nothing to free.
 * NOTE(review): clique_stats_t carries this raw owning pointer (reads), so
 * copying a CliqueWriter would double-delete -- confirm copying is disabled
 * in the header.
 */
CliqueWriter::~CliqueWriter() {
	if (read_list_os == 0) {
		return;
	}
	for (size_t idx = 0; idx != clique_list.size(); ++idx) {
		assert(clique_list[idx].reads != 0);
		delete clique_list[idx].reads;
	}
}

void CliqueWriter::enableReadListOutput(std::ostream& os) {
	assert(!output_all);
	assert(total_count==0);
	assert(read_list_os==0);
	read_list_os = &os;
}

/**
 * Runs the variation caller on a set of alignment pairs and fills *stats.
 *
 * @param pairs    alignment pairs forming the (sub-)clique to be tested
 * @param coverage coverage used in the p-value correction
 * @param stats    output record; must not be null
 *
 * The corrected p-value is min(1, p * 2^coverage), where p is the caller's
 * p-value. is_significant is reset to false; significance is only decided by
 * the final FDR control.
 */
void CliqueWriter::callVariation(const vector<const PackedAlignmentPair*>& pairs, size_t coverage, clique_stats_t* stats) {
	assert(stats != 0);
	VariationCaller::additional_stats_t vc_stats;
	stats->variation = variation_caller.call(pairs.begin(), pairs.end(), &vc_stats);
	stats->clique_size = pairs.size();
	stats->start = vc_stats.insert_start;
	stats->end = vc_stats.insert_end;
	stats->length = vc_stats.insert_length;
	stats->diff = vc_stats.diff;
	stats->total_weight = vc_stats.total_weight;
	stats->coverage = coverage;
	// ldexp(p, k) computes p * 2^k exactly without materializing 2^k. The
	// previous p * pow(2.0, k) overflowed to +inf for k >= 1024, turning
	// p == 0 into NaN (which min() then mapped to 1.0) and tiny p into 1.0.
	const double pvalue = static_cast<double>(stats->variation.getPValue());
	stats->pvalue_corr = min(1.0, ldexp(pvalue, static_cast<int>(stats->coverage)));
	// assume every clique to be not significant until the final FDR control
	stats->is_significant = false;
}

void CliqueWriter::add(std::auto_ptr<Clique> clique) {
	assert(!finished);
	clique_stats_t stats;
	auto_ptr<vector<const PackedAlignmentPair*> > all_pairs = clique->getAllAlignments();
	assert(all_pairs->size() == clique->size());
	if (!multisample) {
		callVariation(*all_pairs, clique->totalBreakpointCoverage(), &stats);
	} else {
		// if read groups are available, run tests for all possible combinations of read groups.
		auto_ptr<vector<size_t> > readgroup_wise_coverages = clique->readGroupWiseCoverage();
		assert(readgroup_wise_coverages.get() != 0);
		assert(readgroup_wise_coverages->size() == read_groups->size());
		// create sample wise coverage counts by summing over read groups
		auto_ptr<vector<size_t> > sample_wise_coverages(new vector<size_t>(read_groups->sampleCount(), 0));
		for (size_t i=0; i<readgroup_wise_coverages->size(); ++i) {
			(*sample_wise_coverages)[read_groups->readGroupIndexToSampleIndex(i)] += readgroup_wise_coverages->at(i);
		}
		// compute number of non-empty subsets of samples
		size_t n = (1<<read_groups->sampleCount()) - 1;
		vector<clique_stats_t> subset_wise_clique_stats(n, clique_stats_t());
		double best_pvalue = numeric_limits<double>::infinity();
		size_t best_pvalue_index = 0;
		std::vector<sample_subset_wise_stats_t> sample_subsset_wise_stats;
		// for every sample, compute probabilty that all alignments are wrong
		vector<double> sample_wrong_probs(read_groups->size(), 1.0);
		for (size_t i=0; i<all_pairs->size(); ++i) {
			assert(all_pairs->at(i) != 0);
			int rg = all_pairs->at(i)->getReadGroup();
			assert(rg >= 0);
			assert(rg < (int)read_groups->size());
			int sample = read_groups->readGroupIndexToSampleIndex(rg);
			assert(sample >= 0);
			assert(sample < (int)read_groups->sampleCount());
			sample_wrong_probs[sample] *= 1.0 - all_pairs->at(i)->getWeight();
		}
		// i is interpreted as a bitvector telling which read groups are active
		for (size_t i=1; i<=n; ++i) {
			clique_stats_t* current_stats = &subset_wise_clique_stats[i-1];
			vector<const PackedAlignmentPair*> pairs;
			for (size_t j=0; j<all_pairs->size(); ++j) {
				int rg = all_pairs->at(j)->getReadGroup();
				assert(rg != -1);
				int sample = read_groups->readGroupIndexToSampleIndex(rg);
				assert(sample >= 0);
				assert(sample < (int)read_groups->sampleCount());
				if ((i & (1<<sample)) != 0) {
					pairs.push_back(all_pairs->at(j));
				}
			}
			if (pairs.size() == 0) {
				sample_subsset_wise_stats.push_back(sample_subset_wise_stats_t(-1.0,0.0));
				continue;
			}
			double p = 1.0;
			size_t coverage = 0;
			for (size_t j=0; j<read_groups->sampleCount(); ++j) {
				if ((i & (1<<j)) != 0) {
					coverage += sample_wise_coverages->at(j);
					p *= 1.0 - sample_wrong_probs[j];
				} else {
					p *= sample_wrong_probs[j];
				}
			}
			callVariation(pairs, coverage, current_stats);
			if (current_stats->pvalue_corr < best_pvalue) {
				best_pvalue = current_stats->pvalue_corr;
				best_pvalue_index = i-1;
			}
			sample_subsset_wise_stats.push_back(sample_subset_wise_stats_t(current_stats->pvalue_corr, p));
		}
		stats = subset_wise_clique_stats[best_pvalue_index];
		stats.sample_subset_wise_stats.assign(sample_subsset_wise_stats.begin(), sample_subsset_wise_stats.end());
		stats.best_sample_combination = best_pvalue_index + 1;
		stats.sample_wise_stats.clear();
		for (size_t i=0; i<read_groups->sampleCount(); ++i) {
			stats.sample_wise_stats.push_back(sample_wise_stats_t(sample_wise_coverages->at(i)));
		}
		// gather read group information
		for (size_t i=0; i<all_pairs->size(); ++i) {
			assert(all_pairs->at(i) != 0);
			int rg = all_pairs->at(i)->getReadGroup();
			assert(rg >= 0);
			assert(rg < (int)read_groups->size());
			int sample = read_groups->readGroupIndexToSampleIndex(rg);
			assert(sample >= 0);
			assert(sample < (int)stats.sample_wise_stats.size());
			stats.sample_wise_stats[sample].add(1,all_pairs->at(i)->getWeight());
		}
	}
	switch (stats.variation.getType()) {
	case Variation::INSERTION:
		total_insertion_cliques += 1;
		break;
	case Variation::DELETION:
		total_deletion_cliques += 1;
		break;
	default:
		break;
	}
	total_count += 1;
	bool passed_fdr = stats.pvalue_corr<=fdr_threshold;
	// if read list is to be printed, we need to store the required
	// information
	if ((read_list_os!=0) && passed_fdr) {
		// retrieve all alignments associated the current clique
		auto_ptr<vector<const PackedAlignmentPair*> > alignments = clique->getAllAlignments();
		stats.reads = new vector<alignment_id_t>();
		for (size_t i=0; i<alignments->size(); ++i) {
			const PackedAlignmentPair& ap = *alignments->at(i);
			alignment_id_t as;
			// does readname already exist?
			readname_to_index_bimap_t::left_const_iterator it = readname_to_index.left.find(ap.getName());
			if (it != readname_to_index.left.end()) {
				as.read_name_idx = it->second;
			} else {
				as.read_name_idx = readname_to_index.left.size();
				readname_to_index.insert(readname_to_index_bimap_t::value_type(ap.getName(), as.read_name_idx));
			}
			as.pair_nr = ap.getPairNr();
			stats.reads->push_back(as);
		}
	}
	if (output_all) {
		os << stats << endl;
	} else {
		if (passed_fdr) {
			clique_list.push_back(stats);
		}
	}
}

void CliqueWriter::writeReadlist() {
	assert(read_list_os != 0);
	// determine the set of reads to be written.
	// "first" gives index of clique, "second" gives read index within this clique,
	// i.e. clique_list[first].reads->at(second) gives an alignment
	typedef pair<size_t,size_t> alignment_index_t;
	// Comparator to sort reads according to their name
	typedef map<size_t, vector<alignment_index_t>, readname_comparator_t> read_to_clique_idx_t;
	read_to_clique_idx_t read_to_clique_idx(readname_comparator_t(*this));
	for (size_t i=0; i<clique_list.size(); ++i) {
		const clique_stats_t& stats = clique_list[i];
		if (!stats.is_significant) continue;
		assert(stats.reads != 0);
		for (size_t j=0; j<stats.reads->size(); ++j) {
			alignment_id_t& aln_stats = stats.reads->at(j);
			// does read already exist in read_to_clique_idx?
			read_to_clique_idx_t::iterator it = read_to_clique_idx.find(aln_stats.read_name_idx);
			if (it != read_to_clique_idx.end()) {
				// if yes, just add variant index to list
				it->second.push_back(make_pair(i,j));
			} else {
				read_to_clique_idx[aln_stats.read_name_idx] = vector<alignment_index_t>();
				read_to_clique_idx[aln_stats.read_name_idx].push_back(make_pair(i,j));
			}
		}
	}
	read_to_clique_idx_t::const_iterator it = read_to_clique_idx.begin();
	for (; it!=read_to_clique_idx.end(); ++it) {
		const string& read_name = readname_to_index.right.at(it->first);
		(*read_list_os) << read_name;
		for (size_t i=0; i<it->second.size(); ++i) {
			const alignment_index_t& aln_idx = it->second[i];
			const alignment_id_t& as = clique_list[aln_idx.first].reads->at(aln_idx.second);
			(*read_list_os) << "\t" << clique_list[aln_idx.first].clique_id << "," << as.pair_nr;
		}
		(*read_list_os) << endl;
	}
}

void CliqueWriter::finish() {
	finished = true;
	if (output_all) return;
	sort(clique_list.begin(), clique_list.end(), clique_stats_comp_t());
	// perform benjamini-hochberg procedure, i.e. determine number significant
	// insertions and deletions
	significant_ins_count = 0;
	significant_del_count = 0;
	size_t insertion_count = 0;
	size_t deletion_count = 0;
	int n = 0;
	for (size_t i=0; i<clique_list.size(); ++i) {
		switch (clique_list[i].variation.getType()) {
		case Variation::INSERTION:
			insertion_count += 1;
			clique_list[i].fdr_level = clique_list[i].pvalue_corr * total_insertion_cliques / insertion_count;
			if (clique_list[i].fdr_level <= fdr_threshold) {
				significant_ins_count = i+1;
				clique_list[i].is_significant = true;
				clique_list[i].clique_id = n++;
			}
			break;
		case Variation::DELETION:
			deletion_count += 1;
			clique_list[i].fdr_level = clique_list[i].pvalue_corr * total_deletion_cliques / deletion_count;
			if (clique_list[i].fdr_level <= fdr_threshold) {
				significant_del_count = i+1;
				clique_list[i].is_significant = true;
				clique_list[i].clique_id = n++;
			}
			break;
		default:
			assert(false);
		}
	}
	for (size_t i=0; i<clique_list.size(); ++i) {
		const clique_stats_t& stats = clique_list[i];
		if (stats.is_significant) {
			if (verbose) {
				os << stats << endl;
			} else {
				os << stats.variation << endl;
			}
		}
	}
	if (read_list_os != 0) {
		writeReadlist();
	}
}

/**
 * Writes a clique's statistics as one line, without a trailing newline.
 * Fields are separated by single spaces:
 * - variation, total_weight, clique_size, coverage, start, end, length, diff,
 *   the raw p-value and the corrected p-value;
 * - fdr_level, only if it is non-negative;
 * - if per-sample statistics are present:
 *   - the comma-separated indices of the samples in the best sample
 *     combination;
 *   - one "absolute_support,expected_support,coverage" entry per sample,
 *     separated by ';';
 *   - one "p_value,probability" entry per sample subset, separated by ';'.
 */
ostream& operator<<(ostream& os, const CliqueWriter::clique_stats_t& stats) {
	os << stats.variation
	   << ' ' << stats.total_weight
	   << ' ' << stats.clique_size
	   << ' ' << stats.coverage
	   << ' ' << stats.start
	   << ' ' << stats.end
	   << ' ' << stats.length
	   << ' ' << stats.diff
	   << ' ' << stats.variation.getPValue()
	   << ' ' << stats.pvalue_corr;
	if (stats.fdr_level >= 0) {
		os << ' ' << stats.fdr_level;
	}
	if (stats.sample_wise_stats.empty()) {
		return os;
	}
	os << ' ';
	// best_sample_combination is a bit mask: bit s set means sample s is included
	assert(stats.best_sample_combination > 0);
	const int combination = stats.best_sample_combination;
	bool need_comma = false;
	for (int sample = 0; (combination >> sample) != 0; ++sample) {
		if (((combination >> sample) & 1) == 0) continue;
		if (need_comma) os << ',';
		os << sample;
		need_comma = true;
	}
	os << ' ';
	for (size_t s = 0; s < stats.sample_wise_stats.size(); ++s) {
		if (s != 0) os << ';';
		os << stats.sample_wise_stats[s].absolute_support
		   << ',' << stats.sample_wise_stats[s].expected_support
		   << ',' << stats.sample_wise_stats[s].coverage;
	}
	os << ' ';
	for (size_t s = 0; s < stats.sample_subset_wise_stats.size(); ++s) {
		if (s != 0) os << ';';
		os << stats.sample_subset_wise_stats[s].p_value
		   << ',' << stats.sample_subset_wise_stats[s].probability;
	}
	return os;
}
